import matplotlib.pyplot as plt
import numpy as np
from neuralnetwork import NeuralNetwork
import scipy

# Network topology: one input per pixel of a 28x28 image (784), 100 hidden
# units, and one output per digit class (10).
input_n, hidden_n, output_n = 784, 100, 10

# Presumably the learning rate used by NeuralNetwork.train -- confirm against
# the neuralnetwork module.
lr = 0.3

net = NeuralNetwork(input_n, hidden_n, output_n, lr)

# Train the network for a single pass over the MNIST training CSV.
# Each non-blank row is expected to be "label,pixel_1,...,pixel_n": the first
# field is the digit label (used as an index into the target vector) and the
# remaining fields are pixel intensities in 0..255 (assumed; verify that the
# row width matches input_n).
#
# Forward slashes work on both Windows and POSIX. The previous "\m" was an
# invalid escape sequence (SyntaxWarning on Python 3.12+) and tied the path
# to Windows.
with open("minst_dataset/mnist_train.csv", "r") as f:
    for record in f:
        record = record.strip()
        # Skip blank lines (e.g. an empty line at the end of the file);
        # parsing "" as a float would raise ValueError.
        if not record:
            continue

        # np.asfarray was removed in NumPy 2.0; asarray(..., dtype=float) is
        # the equivalent replacement.
        values = np.asarray(record.split(","), dtype=float)

        # Scale pixels from [0, 255] into [0.001, 0.991] so that no input is
        # exactly 0 (presumably to keep weight updates non-zero -- confirm
        # against NeuralNetwork.train).
        inputdata = values[1:] / 255.0 * 0.99 + 0.001

        # Target vector: 0.99 at the true label's index, 0.001 everywhere
        # else (avoids exact 0/1 targets).
        target = np.zeros(output_n) + 0.001
        target[int(values[0])] = 0.99

        net.train(inputdata, target)
import pickle

# Serialize the trained network to disk. pickle stores the object's class by
# reference, so loading model.pkl later requires the `neuralnetwork` module to
# be importable. Only unpickle this file from a trusted source.
with open("./model.pkl", "wb") as model_file:
    pickle.dump(net, model_file)

